
import numpy as np
#import matplotlib.pyplot as plt
from scipy import interpolate
from scipy import stats
    
def data_input(ec_file='EC_validation.txt', mrvbf_file='mrvbf_validation.txt'):
    """Load and normalise the EC volume and the MRVBF grid.

    Parameters
    ----------
    ec_file : str
        Comma-separated EC table whose columns 2..11 hold ten 200 x 200
        layers.  Defaults to validation data I; pass 'EC_layers.txt' for the
        training data (same layout).
    mrvbf_file : str
        Comma-separated MRVBF table with one header row.  Columns 2, 3 and 4
        hold the MRVBF value and its x / y coordinates on an 827 x 828 grid.
        Defaults to validation data I; pass 'MRVBF.txt' for the training
        data (same layout).

    Returns
    -------
    mrvbf : ndarray, shape (800, 800)
        MRVBF linearly resampled onto a regular 800 x 800 grid, clipped to
        [0, 7], divided by 7 and flipped vertically.
    image_out : ndarray of float32, shape (1, 200, 200, 10, 1)
        log10(EC), with EC capped at 1000, scaled so that the minimum maps
        to 0 and 1000 (= 10**3) maps to 1.
    """
    # --- EC volume (training / validation I layout) ---
    dat = np.loadtxt(ec_file, delimiter=',')
    ori_ec = np.zeros((200, 200, 10))
    for i in range(10):
        # Column 2+i fills depth slice 9-i, so the layer order is reversed.
        ori_ec[:, :, 9 - i] = np.reshape(dat[:, 2 + i], (200, 200)).T
    ori_ec[ori_ec > 1000] = 1000

    # Validation II uses a different layout (kept for reference):
    #    dat = np.loadtxt('EC_validation_II.txt',delimiter=',')
    #    ori_ec = np.zeros((200,200,10))
    #    for i in range(4):
    #        ori_ec[:,:,i] = np.reshape(dat[:,i],(200,200))
    #    ori_ec[ori_ec>1000]=1000
    #    ori_ec[ori_ec<1]=1.0

    # --- EC normalisation: log10, then min -> 0 and log10(1000) = 3 -> 1 ---
    # NOTE(review): unlike the validation II path there is no lower clamp,
    # so EC values <= 0 would give -inf/nan from log10 -- confirm the data
    # is strictly positive.
    ec = np.log10(ori_ec)
    P = (ec - ec.min()) / (3.0 - ec.min())
    image_out = np.float32(P.reshape(1, 200, 200, 10, 1))

    # --- MRVBF (training / validation I layout); the author's comment says
    # the target grid is 800 x 800 with 100 m cells ---
    dat = np.loadtxt(mrvbf_file, delimiter=',', skiprows=1)
    m = dat[:, 2].reshape(827, 828).T
    x = dat[:, 3].reshape(827, 828).T
    y = dat[:, 4].reshape(827, 828).T
    # Drop the first row, leaving an 827 x 827 grid.
    m = m[1:, :]
    x = x[1:, :]
    y = y[1:, :]

    # Validation II uses transposed dimensions (kept for reference):
    #    dat = np.loadtxt('mrvbf_validation_II.txt',delimiter=',',skiprows=1)
    #    m=dat[:,2].reshape(828,827).T
    #    x=dat[:,3].reshape(828,827).T
    #    y=dat[:,4].reshape(828,827).T
    #    m1=m[:,:-1]
    #    x=x[:,:-1]
    #    y=y[:,:-1]

    # Target coordinates do not change inside the loops; compute them once.
    Xnew = np.linspace(x.min(), x.max(), 800)
    Ynew = np.linspace(y.min(), y.max(), 800)

    # Pass 1: resample every row along x (827 -> 800 columns).
    mf = np.zeros([827, 800])
    for i in range(827):
        f1 = interpolate.interp1d(x[i, :], m[i, :], kind='linear')
        mf[i, :] = f1(Xnew)

    # Pass 2: resample every column along y (827 -> 800 rows).
    # NOTE(review): y[:, i] is column i of the *source* grid, not the column
    # at Xnew[i]; this is only exact if y is constant along each row --
    # confirm the input grid is regular.
    mff = np.zeros([800, 800])
    for i in range(800):
        f2 = interpolate.interp1d(y[:, i], mf[:, i], kind='linear')
        mff[:, i] = f2(Ynew)

    # --- MRVBF normalisation: clip to [0, 7], scale to [0, 1], flip rows ---
    mff = np.clip(mff, 0, 7)
    mrvbf = np.flipud(mff / 7.0)

    return mrvbf, image_out


def data_combine(mrvbf, n_channels=10, *, lower=-0.1, upper=0.1, mu=0.0,
                 sigma=0.2, random_state=None):
    """Build the network input by stacking noisy copies of a 2-D grid.

    Each of the ``n_channels`` channels is ``mrvbf`` plus an independent
    draw of truncated-normal noise (mean ``mu``, std ``sigma``, truncated
    to ``[lower, upper]``).

    Parameters
    ----------
    mrvbf : array_like, 2-D
        Grid to replicate (``data_input`` returns an 800 x 800 grid).
    n_channels : int, optional
        Number of noisy copies to stack; defaults to 10.
    lower, upper, mu, sigma : float, keyword-only
        Truncated-normal noise parameters.  The defaults reproduce the
        original noise: [-0.1, 0.1], mu=0, sigma=0.2.
    random_state : None, int, numpy.random.Generator or RandomState
        ``None`` (default) draws from NumPy's global random state, exactly
        as the original implementation did.  Any other value makes the
        noise reproducible.

    Returns
    -------
    ndarray of float32, shape (1, rows, cols, n_channels, 1)

    Raises
    ------
    ValueError
        If ``mrvbf`` is not two-dimensional.
    """
    mrvbf = np.asarray(mrvbf)
    if mrvbf.ndim != 2:
        raise ValueError("mrvbf must be 2-D, got shape %r" % (mrvbf.shape,))
    nrow, ncol = mrvbf.shape

    # Normalise a seed into a single generator, so that successive channels
    # get different noise instead of re-seeding identically on every draw.
    rng = random_state
    if rng is not None and not isinstance(
            rng, (np.random.Generator, np.random.RandomState)):
        rng = np.random.default_rng(rng)

    # The distribution is loop-invariant; build it once.
    noise = stats.truncnorm((lower - mu) / sigma, (upper - mu) / sigma,
                            loc=mu, scale=sigma)

    stack = np.zeros((nrow, ncol, n_channels))
    for i in range(n_channels):
        # One draw per channel, in channel order, so the global RNG stream
        # is consumed exactly as before when random_state is None.
        wnoise = noise.rvs(size=nrow * ncol, random_state=rng).reshape(nrow, ncol)
        stack[:, :, i] = mrvbf + wnoise

    return np.float32(stack.reshape(1, nrow, ncol, n_channels, 1))

#def data_combine(mrvbf):
#    
#    lower,upper=-0.1, 0.1
#    mu,sigma=0.0, 0.2
#    
##    lower,upper=-0.5, 0.5
##    mu,sigma=0.0, 1.0    
#    
##    lower,upper=-1.0, 1.0
##    mu,sigma=0.0, 2.0
#
#    wnoise=stats.truncnorm((lower-mu)/sigma,(upper-mu)/sigma,loc=mu,scale=sigma).rvs(800*800).reshape(800,800)
#    
#    Image_IN=np.zeros((800,800,2))
#    Image_IN[:,:,0]=mrvbf
#    Image_IN[:,:,1]=wnoise
#    image_in=np.float32(Image_IN.reshape(1,800,800,2,1))
#    
#    return image_in